#!/usr/bin/env python
# -*- coding: utf8 -*-

import sys

# REverse compiler for ST music modules MUZ/ST7/ST8.
# (c) 2011 by Mono

# Set to True to print every song entry reconstructed by Converter.convert.
log = False

# Raised when a file is not a binary file
class NotBinaryFileException (Exception):
	"""The stream *file* does not start with the binary file marker."""
	def __init__(self, file):
		super(NotBinaryFileException, self).__init__(file)
		self.file = file	# the stream that was being read

# Raised when the file contains no further binary blocks
class NoMoreDataException (Exception):
	"""End of data: no further block can be read from *file*."""
	def __init__(self, file):
		super(NoMoreDataException, self).__init__(file)
		self.file = file	# the stream that was being read

# Block (segment) of a binary file
class BinaryBlock:
	"""One block of a binary file, read from the current position of *file*.

	On disk a block is: an optional $FFFF marker, the first and the last
	load address (little-endian words) and last-first+1 data bytes.

	Raises NoMoreDataException when no further block starts here, and
	NotBinaryFileException when *requireHeader* is true but the $FFFF
	marker is missing.
	"""

	def __init__(self, file, requireHeader = False):
		b2w = lambda a, b: a+256*b

		arr = bytearray(file.read(2))
		# Nothing left, or a lone stray byte which cannot start a block.
		if len(arr) < 2:
			raise NoMoreDataException(file)
		# Both bytes must be $FF: a headerless block whose load address merely
		# contains one $FF byte (e.g. $20FF or $FF00) has no marker.
		if arr[0] == 0xff and arr[1] == 0xff:
			arr = bytearray(file.read(2))
		elif requireHeader:
			raise NotBinaryFileException(file)
		self.first = b2w(arr[0], arr[1])
		arr = bytearray(file.read(2))
		self.last = b2w(arr[0], arr[1])
		self.content = bytearray(file.read(self.last + 1 - self.first))

	def __str__(self):
		return "$%04x..$%04x" % (self.first, self.last)

# Binary file: the list of its blocks
class BinaryFile:
	"""All blocks of a binary file, read in file order into self.blocks."""

	def __init__(self, filename, requireHeader = False):
		"""Read every block of *filename*.

		requireHeader -- when true, the first block must begin with the $FFFF
		marker (NotBinaryFileException otherwise). The default (False) also
		accepts a first block without the marker. Later blocks never need it.
		"""
		self.blocks = []
		with open(filename, "rb") as stream:
			try:
				self.blocks.append(BinaryBlock(stream, requireHeader))
				while True:
					self.blocks.append(BinaryBlock(stream))
			except NoMoreDataException:
				# Normal end of the file.
				pass

	def __str__(self):
		# __str__ must return a string; returning the list made str() raise TypeError.
		return "[%s]" % ", ".join(str(block) for block in self.blocks)

class CompiledTrack:
	"""One track of a compiled module: its loop byte and its raw command bytes."""

	def __init__(self, loop, data):
		self.loop, self.data = loop, data

	def __str__(self):
		return ", ".join(["$%02x" % (self.loop,), str(self.data)])

class CompiledInstrument:
	"""The four bytes of one instrument-table entry of a compiled module."""

	def __init__(self, volumeEnvelopeLoop, frequencyEnvelopeLoop, distortion, reserved):
		(self.volumeEnvelopeLoop, self.frequencyEnvelopeLoop,
			self.distortion, self.reserved) = (volumeEnvelopeLoop, frequencyEnvelopeLoop, distortion, reserved)

	def __str__(self):
		fields = (self.volumeEnvelopeLoop, self.frequencyEnvelopeLoop, self.distortion, self.reserved)
		return ", ".join("$%02x" % value for value in fields)

class CompiledPattern:
	"""Raw bytes of one pattern taken from a compiled module."""
	def __init__(self, data):
		self.data = data

	def __str__(self):
		# str() rather than "%s" % data: tuple data would be taken as the
		# format argument list (TypeError, or lost parentheses).
		return str(self.data)

class CompiledEnvelope:
	"""Raw bytes of one (volume or frequency) envelope from a compiled module."""
	def __init__(self, data):
		self.data = data

	def __str__(self):
		# str() rather than "%s" % data: tuple data would be taken as the
		# format argument list (TypeError, or lost parentheses).
		return str(self.data)

# Compiled ST module
class CompiledModule:
	"""A compiled ST module split into tracks, patterns, envelopes and instruments.

	Header layout, as read below (offsets into *data*; words are
	little-endian absolute addresses):
	  0..3    loop byte of each of the 4 tracks
	  4..7    low bytes of the track addresses
	  8..11   high bytes of the track addresses
	  12..13  pattern address table
	  14..15  volume envelope address table
	  16..17  frequency envelope address table
	  18..19  instrument table (4 bytes per instrument)

	No lengths or counts are stored: the size of every table and data item
	is the distance to the nearest known start offset that follows it.
	"""

	def __init__(self, address, data):
		"""Parse *data*, the module image loaded at absolute *address*."""
		b2w = lambda a, b: a+256*b

		patternAddressTableOffset = b2w(data[12], data[13]) - address
		volumeEnvelopeAddressTableOffset = b2w(data[14], data[15]) - address
		frequencyEnvelopeAddressTableOffset = b2w(data[16], data[17]) - address
		instrumentsTableOffset = b2w(data[18], data[19]) - address

		tracksCount = 4
		trackOffsets = [b2w(data[4+n], data[8+n]) - address for n in range(tracksCount)]

		# Every offset at which some table or data item is known to start.
		offsets = [patternAddressTableOffset, volumeEnvelopeAddressTableOffset,
			frequencyEnvelopeAddressTableOffset, instrumentsTableOffset, len(data)]
		offsets.extend(trackOffsets)

		def findNext(offset):
			# Several entries may point at the same data, so offsets equal to
			# *offset* are skipped; taking the next list element would return
			# the duplicate and produce an empty item.
			return min(o for o in offsets if o > offset)

		def tableCount(tableOffset, entrySize):
			# Another item starting at the very same offset means an empty table.
			if offsets.count(tableOffset) > 1:
				return 0
			return (findNext(tableOffset) - tableOffset) // entrySize

		def readAddressTable(tableOffset):
			# Offsets (relative to data) listed in a table of address words.
			return [b2w(data[tableOffset+2*n], data[tableOffset+2*n+1]) - address
				for n in range(tableCount(tableOffset, 2))]

		# Tables are read one after another; each is sized against the offsets
		# collected so far, including the entries of the tables read before it.
		patternOffsets = readAddressTable(patternAddressTableOffset)
		offsets.extend(patternOffsets)
		volumeEnvelopeOffsets = readAddressTable(volumeEnvelopeAddressTableOffset)
		offsets.extend(volumeEnvelopeOffsets)
		frequencyEnvelopeOffsets = readAddressTable(frequencyEnvelopeAddressTableOffset)
		offsets.extend(frequencyEnvelopeOffsets)
		instrumentsCount = tableCount(instrumentsTableOffset, 4)

		print("Tracks: %i" % tracksCount)
		# $FD and $FE are followed by an argument byte, which may itself be $FF.
		trackData = self._sliceItems(data, trackOffsets, findNext, "track", (0xfd, 0xfe))
		self.tracks = [CompiledTrack(data[n], trackData[n]) for n in range(tracksCount)]

		print("Patterns: %i" % len(patternOffsets))
		self.patterns = [CompiledPattern(content) for content in
			self._sliceItems(data, patternOffsets, findNext, "pattern")]

		print("Volume envelopes: %i" % len(volumeEnvelopeOffsets))
		self.volumeEnvelopes = [CompiledEnvelope(content) for content in
			self._sliceItems(data, volumeEnvelopeOffsets, findNext, "volume envelope")]

		print("Frequency envelopes: %i" % len(frequencyEnvelopeOffsets))
		self.frequencyEnvelopes = [CompiledEnvelope(content) for content in
			self._sliceItems(data, frequencyEnvelopeOffsets, findNext, "frequency envelope")]

		print("Instruments: %i" % instrumentsCount)
		self.instruments = []
		for n in range(instrumentsCount):
			entry = instrumentsTableOffset + 4*n
			self.instruments.append(CompiledInstrument(data[entry], data[entry+1], data[entry+2], data[entry+3]))

	def _sliceItems(self, data, itemOffsets, findNext, what, twoByteCommands=()):
		"""Return data[start:end] for every start offset in *itemOffsets*.

		An item ends at the next known offset, or right after its first $FF
		end marker when that comes earlier (reported as a bad length). A byte
		listed in *twoByteCommands* is followed by an argument byte that is
		skipped instead of being tested for the end marker.
		"""
		items = []
		for n, startOffset in enumerate(itemOffsets):
			stopOffset = findNext(startOffset)
			o = startOffset
			while o < stopOffset:
				if data[o] == 0xff:
					if o+1 != stopOffset:
						print("Bad %s %i length - shrinking data at $%04x" % (what, n+1, o+1))
						stopOffset = o+1
					break
				o += 2 if data[o] in twoByteCommands else 1
			items.append(data[startOffset:stopOffset])
		return items

	def __str__(self):
		# str() each element; "%s" on a plain list shows the objects' repr().
		asText = lambda items: "[%s]" % ", ".join(str(item) for item in items)
		return "%s, %s, %s, %s, %s" % (asText(self.tracks), asText(self.patterns),
			asText(self.volumeEnvelopes), asText(self.frequencyEnvelopes), asText(self.instruments))

class SourceSong:
	"""The song in source form: loop position and one pattern number per step."""
	def __init__(self, loop, data):
		self.loop, self.data = loop, data

class SourceTrack:
	"""Data of one channel inside a source pattern, kept exactly as given."""
	def __init__(self, data):
		setattr(self, "data", data)

class SourcePattern:
	"""One source pattern: row count, tempo, AUDCTL value and its tracks."""
	def __init__(self, length, tempo, audctl, tracks):
		self.length, self.tempo = length, tempo
		self.audctl, self.tracks = audctl, tracks

class SourceInstrument:
	"""A source instrument: volume and frequency envelopes plus the
	distortion and reserved bytes."""
	def __init__(self, volumeEnvelope, frequencyEnvelope, distortion, reserved):
		self.volumeEnvelope, self.frequencyEnvelope = volumeEnvelope, frequencyEnvelope
		self.distortion, self.reserved = distortion, reserved

class SourceEnvelope:
	"""An envelope in source form: loop point and envelope values."""
	def __init__(self, loop, data):
		self.loop, self.data = loop, data

# ST module in source form
class SourceModule:
	"""An ST module in source form: instruments, patterns and the song."""

	def __init__(self, instruments, patterns, song):
		self.instruments = instruments
		self.patterns = patterns
		self.song = song

	def __str__(self):
		return "%s, %s, %s" % (self.instruments, self.patterns, self.song)

	def write(self, filename):
		"""Write the module to *filename* as a source file.

		Every count, length and parameter is stored as a single byte, so a
		value outside 0..255 raises ValueError. The whole image is assembled
		in memory first, so in that case the output file is left untouched
		instead of being truncated half-way.
		"""
		image = self._encode()
		with open(filename, "wb") as stream:
			stream.write(bytes(image))

	def _encode(self):
		"""Return the complete source-file image as a bytearray."""
		out = bytearray(b"Music ")

		print("Instruments: %i" % len(self.instruments))
		out.append(len(self.instruments))
		for n, instrument in enumerate(self.instruments):
			volume = instrument.volumeEnvelope
			frequency = instrument.frequencyEnvelope
			out.append(n)
			out.append(len(volume.data))
			out.append(volume.loop)
			out.append(instrument.distortion)
			out.extend(volume.data)
			out.append(len(frequency.data))
			out.append(frequency.loop)
			out.append(instrument.reserved)
			out.extend(frequency.data)

		print("Patterns: %i" % len(self.patterns))
		out.append(len(self.patterns))
		for n, pattern in enumerate(self.patterns):
			out.append(n)
			out.append(pattern.length)
			out.append(pattern.tempo)
			out.append(pattern.audctl)
			# Exactly four tracks are written for every pattern.
			for t in range(4):
				track = pattern.tracks[t]
				out.append(len(track.data))
				out.extend(track.data)

		print("Song: %i" % len(self.song.data))
		out.append(self.song.loop)
		out.append(len(self.song.data))
		out.extend(self.song.data)
		return out

# Converter from the compiled form to the source form
class Converter:
	"""Rebuild a SourceModule from a CompiledModule."""

	def convert(self, module):
		"""Return the SourceModule equivalent of *module*.

		The tracks are walked in lock-step: each step takes the next pattern
		of every track (after any tempo/AUDCTL commands in front of it) and
		becomes one song entry. Identical entries share one source pattern.
		"""
		songEntries, songLoop = self._collectSongEntries(module)
		song, uniqueEntries = self._buildSong(songEntries, songLoop)
		patterns = self._buildPatterns(module, uniqueEntries)
		instruments = self._buildInstruments(module)
		return SourceModule(instruments, patterns, song)

	def _collectSongEntries(self, module):
		"""Walk all tracks together; return (entries, songLoop).

		An entry is the tuple (length, tempo, audctl, pattern of each track).
		"""
		tracksCount = len(module.tracks)
		# Plain lists rather than bytearrays, so positions are not capped at 255.
		trackPositions = [0] * tracksCount
		trackLoopCorrections = [0] * tracksCount
		trackCommands = [0] * tracksCount
		patternTracks = [0] * tracksCount
		patternLengths = [0] * tracksCount
		patternTempo = 0
		patternAudctl = 0
		entries = []
		while True:
			for n, track in enumerate(module.tracks):
				position = trackPositions[n]
				while True:
					command = track.data[position]
					trackCommands[n] = command
					if command == 0xff:
						# End of track; the position stays on the marker.
						break
					elif command == 0xfe or command == 0xfd:
						# Two-byte command ($FE tempo, $FD AUDCTL). The loop byte is a
						# byte position; subtracting the bytes of the commands before
						# it leaves the number of one-byte pattern entries.
						if track.loop > position:
							trackLoopCorrections[n] += 2
						if command == 0xfe:
							patternTempo = track.data[position+1]
						else:
							patternAudctl = track.data[position+1]
						position += 2
					else:
						patternTracks[n] = command
						patternLengths[n] = self._patternLength(module.patterns[command].data)
						position += 1
						break
				trackPositions[n] = position
			# A pass in which every track only reached its end marker yields no
			# new step; appending it would repeat the last entry at the song end.
			if all(command == 0xff for command in trackCommands):
				break
			entry = tuple([max(patternLengths), patternTempo, patternAudctl] + patternTracks)
			if log:
				print("%03i: " % len(entries) + "L:$%02x, T:$%02x, A:$%02x, " % entry[:3]
					+ ", ".join("$%02x" % p for p in entry[3:]))
			entries.append(entry)
		songLoop = module.tracks[0].loop - trackLoopCorrections[0]
		return entries, songLoop

	def _patternLength(self, data):
		"""Return the number of rows a compiled pattern lasts.

		$00-$3F: a note, one row; when a note is followed by another note or by
		the end marker, the current duration is added as well.
		$40-$7F: does not change the length.
		$80-$FE: sets the duration (low 7 bits) and adds it.
		$FF: end of pattern.
		"""
		length = 0
		duration = 0
		lastNote = False
		for command in data:
			if command == 0xff:
				if lastNote:
					length += duration
				break
			elif command < 0x40:
				length += 1
				if lastNote:
					length += duration
				lastNote = True
			elif command >= 0x80:
				duration = command & 0x7f
				length += duration
				lastNote = False
		return length

	def _buildSong(self, entries, songLoop):
		"""Number distinct entries in order of first use; return (SourceSong, distinct entries)."""
		songData = bytearray()
		uniqueEntries = []
		indexOf = {}
		for entry in entries:
			if entry not in indexOf:
				indexOf[entry] = len(uniqueEntries)
				uniqueEntries.append(entry)
			songData.append(indexOf[entry])
		return SourceSong(songLoop, songData), uniqueEntries

	def _buildPatterns(self, module, uniqueEntries):
		"""Create one SourcePattern per distinct song entry."""
		patterns = []
		for entry in uniqueEntries:
			length, tempo, audctl = entry[:3]
			# The last byte of a compiled pattern (normally its $FF end marker) is dropped.
			tracks = [SourceTrack(module.patterns[p].data[:-1]) for p in entry[3:]]
			patterns.append(SourcePattern(length, tempo, audctl, tracks))
		return patterns

	def _buildInstruments(self, module):
		"""Create one SourceInstrument per compiled instrument; envelopes are matched by index."""
		instruments = []
		for i, instrument in enumerate(module.instruments):
			volumeData = module.volumeEnvelopes[i].data
			frequencyData = module.frequencyEnvelopes[i].data
			# Source loop points are one higher than the compiled ones, and the
			# last envelope byte (normally the end marker) is dropped.
			volumeEnvelope = SourceEnvelope(instrument.volumeEnvelopeLoop+1, volumeData[:-1])
			frequencyEnvelope = SourceEnvelope(instrument.frequencyEnvelopeLoop+1, frequencyData[:-1])
			instruments.append(SourceInstrument(volumeEnvelope, frequencyEnvelope, instrument.distortion, instrument.reserved))
		return instruments

# Main entry point
if __name__ == "__main__":
	import os

	progname = sys.argv[0]
	if len(sys.argv) > 1:
		modname = sys.argv[1]
		if len(sys.argv) > 2:
			sourcename = sys.argv[2]
		else:
			# Default: the input name with its extension replaced by .MUZ.
			# splitext only looks at the last path component, so a dot in a
			# directory name ("../x", "dir.v2/x") is not taken as an extension.
			sourcename = os.path.splitext(modname)[0] + ".MUZ"
			# Compiled modules may themselves be named *.MUZ; the default name
			# must not silently replace the compiled input.
			if sourcename.lower() == modname.lower():
				print("Default output %s is the input file - give sourcename explicitly" % sourcename)
				sys.exit(1)
		print("Reading %s compiled file" % modname)
		binfile = BinaryFile(modname)
		if len(binfile.blocks) > 0:
			if len(binfile.blocks) > 1:
				print("File has %i blocks - get only first" % len(binfile.blocks))
			block = binfile.blocks[0]
			module = CompiledModule(block.first, block.content)
			source = Converter().convert(module)
			print("Writing %s source file" % sourcename)
			source.write(sourcename)
		else:
			print("No blocks in file")
	else:
		print("REverse compiler for ST 07/08 music modules (ST7/ST8/MUZ).")
		print("(c) 2011 by Mono / Tristesse")
		print("\nUsage: %s compilename [sourcename]" % progname)
		print("\nGenerates .MUZ source file from compiled file.")

